In [1]:
%matplotlib inline

DCGAN Tutorial

Author: Nathan Inkawhich (https://github.com/inkawhich)

In [5]:
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
Random Seed:  999
Out[5]:
<torch._C.Generator at 0x7fbcac03cd70>

Inputs

Let’s define some inputs for the run:

  • dataroot - the path to the root of the dataset folder. We will talk more about the dataset in the next section
  • workers - the number of worker processes for loading the data with the DataLoader
  • batch_size - the batch size used in training. The DCGAN paper uses a batch size of 128
  • image_size - the spatial size of the images used for training. This implementation defaults to 64x64. If another size is desired, the structures of D and G must be changed. See https://github.com/pytorch/examples/issues/70 for more details
  • nc - the number of color channels in the input images. For color images this is 3
  • nz - the length of the latent vector
  • ngf - relates to the depth of feature maps carried through the generator
  • ndf - sets the depth of feature maps propagated through the discriminator
  • num_epochs - the number of training epochs to run. Training for longer will probably lead to better results, but will also take much longer
  • lr - the learning rate for training. As described in the DCGAN paper, this number should be 0.0002
  • beta1 - the beta1 hyperparameter for the Adam optimizers. As described in the paper, this number should be 0.5
  • ngpu - the number of GPUs available. If this is 0, the code will run in CPU mode. If this number is greater than 0, it will run on that number of GPUs
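The 64x64 constraint behind image_size follows from the transposed-convolution size arithmetic used in the generator later on. As a standalone sketch (assuming the default dilation=1 and output_padding=0), we can trace how the five generator stages grow a 1x1 latent into a 64x64 image:

```python
def convtranspose2d_out(size, kernel, stride, padding):
    # Standard transposed-convolution output-size formula
    # (assumes dilation=1 and output_padding=0)
    return (size - 1) * stride - 2 * padding + kernel

# The five generator stages used later: 1 -> 4 -> 8 -> 16 -> 32 -> 64
size = 1
for kernel, stride, padding in [(4, 1, 0)] + [(4, 2, 1)] * 4:
    size = convtranspose2d_out(size, kernel, stride, padding)
print(size)  # 64
```

Changing image_size means re-deriving a stack of stages that lands exactly on the new resolution, which is why both D and G must be edited for other sizes.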
In [6]:
# Root directory for dataset
dataroot = "../data/raw/planctons_original"

# Number of workers for dataloader
workers = 10

# Batch size during training
batch_size = 128

# Spatial size of training images. All images will be resized to this
#   size using a transformer.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 200

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

Data

In [7]:
# We can use an image folder dataset the way we have it setup.
# Create the dataset
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
Out[7]:
<matplotlib.image.AxesImage at 0x7fbc5050de20>

Implementation

With our input parameters set and the dataset prepared, we can now get into the implementation. We will start with the weight initialization strategy, then cover the generator, discriminator, loss functions, and training loop in detail.

Weight Initialization

From the DCGAN paper, the authors specify that all model weights shall be randomly initialized from a Normal distribution with mean=0, stdev=0.02. The weights_init function takes an initialized model as input and reinitializes all convolutional, convolutional-transpose, and batch normalization layers to meet this criterion. This function is applied to the models immediately after initialization.

In [8]:
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
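
As a quick sanity check (a standalone sketch, not part of the tutorial itself), applying weights_init to a fresh Conv2d layer should leave its weights with sample statistics near mean=0, stdev=0.02:

```python
import torch
import torch.nn as nn

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

torch.manual_seed(0)
layer = nn.Conv2d(3, 64, 4)  # 3*64*4*4 = 3072 weight entries
layer.apply(weights_init)
# Sample mean should be close to 0 and sample std close to 0.02
print(layer.weight.data.mean().item(), layer.weight.data.std().item())
```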
In [9]:
# Generator Code

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)
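
Before wiring the generator into training, it can help to confirm the shape bookkeeping on a couple of stages in isolation. This is a standalone sketch using the hyperparameters defined above (nz=100, ngf=64):

```python
import torch
import torch.nn as nn

nz, ngf = 100, 64
project = nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False)        # 1x1 -> 4x4
upsample = nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False)  # 4x4 -> 8x8

z = torch.randn(16, nz, 1, 1)  # batch of 16 latent vectors
h = project(z)
print(h.shape)            # torch.Size([16, 512, 4, 4])
print(upsample(h).shape)  # torch.Size([16, 256, 8, 8])
```

Each subsequent stride-2 stage doubles the spatial size the same way, until the final stage emits the (nc) x 64 x 64 image.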

Now, we can instantiate the generator and apply the weights_init function. Check out the printed model to see how the generator object is structured.

In [10]:
# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.02.
netG.apply(weights_init)

# Print the model
print(netG)
Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

Discriminator Code

In [11]:
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)
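
The discriminator mirrors the generator: each strided Conv2d halves the spatial size. A minimal standalone shape check of the first stage, assuming nc=3 and ndf=64 from above:

```python
import torch
import torch.nn as nn

nc, ndf = 3, 64
down = nn.Conv2d(nc, ndf, 4, 2, 1, bias=False)  # 64x64 -> 32x32

x = torch.randn(16, nc, 64, 64)  # a fake batch of 64x64 RGB images
print(down(x).shape)  # torch.Size([16, 64, 32, 32])
```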

Now, as with the generator, we can create the discriminator, apply the weights_init function, and print the model’s structure.

In [12]:
# Create the Discriminator
netD = Discriminator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))
    
# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.02.
netD.apply(weights_init)

# Print the model
print(netD)
Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)
In [13]:
# Initialize BCELoss function
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1
fake_label = 0

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
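For reference, BCELoss computes -[y*log(x) + (1-y)*log(1-x)] averaged over the batch, so with the convention above a "real" label penalizes D for outputs below 1, and a "fake" label for outputs above 0. A tiny numeric sketch (the output values are chosen purely for illustration):

```python
import math
import torch
import torch.nn as nn

criterion = nn.BCELoss()
output = torch.tensor([0.9, 0.1])  # hypothetical D outputs for two samples
label = torch.ones(2)              # both labeled "real"

loss = criterion(output, label)
# mean of -log(0.9) and -log(0.1), i.e. roughly 1.2040
print(loss.item())
```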
In [14]:
%%time
# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs): 
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()
        
        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        
        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        
        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
            
        iters += 1
Starting Training Loop...
[0/200][0/29]	Loss_D: 2.0496	Loss_G: 5.4277	D(x): 0.5714	D(G(z)): 0.6765 / 0.0087
[1/200][0/29]	Loss_D: 4.2876	Loss_G: 25.6962	D(x): 0.2498	D(G(z)): 0.0000 / 0.0000
[2/200][0/29]	Loss_D: 0.3011	Loss_G: 13.0003	D(x): 0.9983	D(G(z)): 0.2440 / 0.0000
[3/200][0/29]	Loss_D: 0.0893	Loss_G: 6.2853	D(x): 0.9692	D(G(z)): 0.0517 / 0.0022
[4/200][0/29]	Loss_D: 0.0862	Loss_G: 6.3527	D(x): 0.9555	D(G(z)): 0.0378 / 0.0027
[5/200][0/29]	Loss_D: 2.2840	Loss_G: 0.9106	D(x): 0.1607	D(G(z)): 0.0117 / 0.5049
[6/200][0/29]	Loss_D: 0.7240	Loss_G: 2.6944	D(x): 0.7996	D(G(z)): 0.3739 / 0.0758
[7/200][0/29]	Loss_D: 1.3219	Loss_G: 3.1662	D(x): 0.8770	D(G(z)): 0.6689 / 0.0533
[8/200][0/29]	Loss_D: 0.5178	Loss_G: 3.2181	D(x): 0.7850	D(G(z)): 0.2041 / 0.0451
[9/200][0/29]	Loss_D: 0.5658	Loss_G: 3.5959	D(x): 0.7374	D(G(z)): 0.1417 / 0.0478
[10/200][0/29]	Loss_D: 0.1529	Loss_G: 5.5031	D(x): 0.9425	D(G(z)): 0.0847 / 0.0059
[11/200][0/29]	Loss_D: 0.1472	Loss_G: 4.9659	D(x): 0.9726	D(G(z)): 0.1095 / 0.0111
[12/200][0/29]	Loss_D: 0.8499	Loss_G: 3.6523	D(x): 0.5766	D(G(z)): 0.1264 / 0.0637
[13/200][0/29]	Loss_D: 2.9623	Loss_G: 1.9975	D(x): 0.0711	D(G(z)): 0.0012 / 0.1582
[14/200][0/29]	Loss_D: 0.8611	Loss_G: 4.6825	D(x): 0.9437	D(G(z)): 0.5062 / 0.0271
[15/200][0/29]	Loss_D: 0.3578	Loss_G: 3.7183	D(x): 0.8038	D(G(z)): 0.1064 / 0.0320
[16/200][0/29]	Loss_D: 0.5969	Loss_G: 2.1818	D(x): 0.6030	D(G(z)): 0.0358 / 0.1835
[17/200][0/29]	Loss_D: 0.7226	Loss_G: 2.2538	D(x): 0.6480	D(G(z)): 0.2176 / 0.1144
[18/200][0/29]	Loss_D: 0.7124	Loss_G: 2.1017	D(x): 0.7105	D(G(z)): 0.2656 / 0.1547
[19/200][0/29]	Loss_D: 0.6952	Loss_G: 3.4630	D(x): 0.7231	D(G(z)): 0.2584 / 0.0444
[20/200][0/29]	Loss_D: 0.7541	Loss_G: 2.8725	D(x): 0.7874	D(G(z)): 0.3621 / 0.0715
[21/200][0/29]	Loss_D: 0.4743	Loss_G: 3.4121	D(x): 0.8086	D(G(z)): 0.1935 / 0.0460
[22/200][0/29]	Loss_D: 0.9072	Loss_G: 3.5397	D(x): 0.8597	D(G(z)): 0.4949 / 0.0368
[23/200][0/29]	Loss_D: 0.8678	Loss_G: 6.5546	D(x): 0.8460	D(G(z)): 0.4185 / 0.0036
[24/200][0/29]	Loss_D: 1.1546	Loss_G: 7.6662	D(x): 0.9708	D(G(z)): 0.6342 / 0.0013
[25/200][0/29]	Loss_D: 0.7224	Loss_G: 4.0446	D(x): 0.9269	D(G(z)): 0.4452 / 0.0254
[26/200][0/29]	Loss_D: 0.7146	Loss_G: 4.7036	D(x): 0.8011	D(G(z)): 0.3490 / 0.0133
[27/200][0/29]	Loss_D: 0.5026	Loss_G: 2.7392	D(x): 0.6865	D(G(z)): 0.0753 / 0.0865
[28/200][0/29]	Loss_D: 0.5288	Loss_G: 4.1984	D(x): 0.8457	D(G(z)): 0.2497 / 0.0267
[29/200][0/29]	Loss_D: 0.5091	Loss_G: 3.1260	D(x): 0.8348	D(G(z)): 0.2458 / 0.0694
[30/200][0/29]	Loss_D: 0.7263	Loss_G: 4.1690	D(x): 0.7972	D(G(z)): 0.3587 / 0.0205
[31/200][0/29]	Loss_D: 0.5474	Loss_G: 1.6678	D(x): 0.7032	D(G(z)): 0.1167 / 0.2189
[32/200][0/29]	Loss_D: 0.8779	Loss_G: 3.8331	D(x): 0.8599	D(G(z)): 0.4630 / 0.0410
[33/200][0/29]	Loss_D: 0.3827	Loss_G: 4.0054	D(x): 0.8420	D(G(z)): 0.1605 / 0.0297
[34/200][0/29]	Loss_D: 0.4910	Loss_G: 3.7995	D(x): 0.8493	D(G(z)): 0.2617 / 0.0265
[35/200][0/29]	Loss_D: 0.4110	Loss_G: 3.7146	D(x): 0.8592	D(G(z)): 0.2123 / 0.0313
[36/200][0/29]	Loss_D: 0.6831	Loss_G: 1.6835	D(x): 0.6294	D(G(z)): 0.1386 / 0.2297
[37/200][0/29]	Loss_D: 0.4561	Loss_G: 3.2626	D(x): 0.8204	D(G(z)): 0.2044 / 0.0524
[38/200][0/29]	Loss_D: 0.5376	Loss_G: 2.8614	D(x): 0.7359	D(G(z)): 0.1699 / 0.0736
[39/200][0/29]	Loss_D: 0.4391	Loss_G: 4.1306	D(x): 0.9085	D(G(z)): 0.2690 / 0.0198
[40/200][0/29]	Loss_D: 0.4263	Loss_G: 3.9555	D(x): 0.9591	D(G(z)): 0.3002 / 0.0260
[41/200][0/29]	Loss_D: 0.4536	Loss_G: 2.5445	D(x): 0.7456	D(G(z)): 0.1114 / 0.1030
[42/200][0/29]	Loss_D: 0.4235	Loss_G: 3.1432	D(x): 0.8179	D(G(z)): 0.1677 / 0.0600
[43/200][0/29]	Loss_D: 0.9452	Loss_G: 5.3538	D(x): 0.9627	D(G(z)): 0.5445 / 0.0071
[44/200][0/29]	Loss_D: 0.5101	Loss_G: 2.7613	D(x): 0.6528	D(G(z)): 0.0330 / 0.0897
[45/200][0/29]	Loss_D: 0.5667	Loss_G: 3.9558	D(x): 0.8814	D(G(z)): 0.3227 / 0.0272
[46/200][0/29]	Loss_D: 0.9570	Loss_G: 4.9540	D(x): 0.9824	D(G(z)): 0.5669 / 0.0114
[47/200][0/29]	Loss_D: 0.4039	Loss_G: 2.9427	D(x): 0.7446	D(G(z)): 0.0832 / 0.0848
[48/200][0/29]	Loss_D: 0.4822	Loss_G: 4.7272	D(x): 0.8849	D(G(z)): 0.2524 / 0.0161
[49/200][0/29]	Loss_D: 0.5154	Loss_G: 3.4500	D(x): 0.9270	D(G(z)): 0.3257 / 0.0459
[50/200][0/29]	Loss_D: 0.8098	Loss_G: 4.2288	D(x): 0.8982	D(G(z)): 0.4578 / 0.0215
[51/200][0/29]	Loss_D: 1.4680	Loss_G: 6.2739	D(x): 0.9773	D(G(z)): 0.7107 / 0.0030
[52/200][0/29]	Loss_D: 0.7641	Loss_G: 4.6987	D(x): 0.8587	D(G(z)): 0.4175 / 0.0132
[53/200][0/29]	Loss_D: 0.3469	Loss_G: 2.7661	D(x): 0.8207	D(G(z)): 0.1228 / 0.0789
[54/200][0/29]	Loss_D: 1.3513	Loss_G: 5.8061	D(x): 0.9793	D(G(z)): 0.6565 / 0.0071
[55/200][0/29]	Loss_D: 0.3104	Loss_G: 4.0616	D(x): 0.8216	D(G(z)): 0.0919 / 0.0370
[56/200][0/29]	Loss_D: 0.5714	Loss_G: 3.1832	D(x): 0.6185	D(G(z)): 0.0198 / 0.0725
[57/200][0/29]	Loss_D: 0.7300	Loss_G: 1.6991	D(x): 0.6054	D(G(z)): 0.1019 / 0.2645
[58/200][0/29]	Loss_D: 0.3914	Loss_G: 2.5967	D(x): 0.7843	D(G(z)): 0.1140 / 0.1126
[59/200][0/29]	Loss_D: 0.3479	Loss_G: 3.3952	D(x): 0.8095	D(G(z)): 0.1040 / 0.0483
[60/200][0/29]	Loss_D: 0.4284	Loss_G: 4.4067	D(x): 0.9158	D(G(z)): 0.2648 / 0.0169
[61/200][0/29]	Loss_D: 0.2916	Loss_G: 3.4287	D(x): 0.8918	D(G(z)): 0.1458 / 0.0496
[62/200][0/29]	Loss_D: 0.5677	Loss_G: 2.9326	D(x): 0.8905	D(G(z)): 0.3228 / 0.0788
[63/200][0/29]	Loss_D: 0.6134	Loss_G: 2.4610	D(x): 0.6231	D(G(z)): 0.0597 / 0.1172
[64/200][0/29]	Loss_D: 0.2623	Loss_G: 3.1199	D(x): 0.9439	D(G(z)): 0.1741 / 0.0644
[65/200][0/29]	Loss_D: 1.1991	Loss_G: 1.6009	D(x): 0.3687	D(G(z)): 0.0262 / 0.2654
[66/200][0/29]	Loss_D: 0.7422	Loss_G: 5.5751	D(x): 0.9497	D(G(z)): 0.4627 / 0.0058
[67/200][0/29]	Loss_D: 0.4565	Loss_G: 3.8843	D(x): 0.6969	D(G(z)): 0.0239 / 0.0415
[68/200][0/29]	Loss_D: 0.8007	Loss_G: 5.4493	D(x): 0.9054	D(G(z)): 0.4591 / 0.0067
[69/200][0/29]	Loss_D: 0.6669	Loss_G: 3.9097	D(x): 0.8368	D(G(z)): 0.3421 / 0.0293
[70/200][0/29]	Loss_D: 0.7063	Loss_G: 4.2685	D(x): 0.8334	D(G(z)): 0.3738 / 0.0198
[71/200][0/29]	Loss_D: 0.6572	Loss_G: 2.9428	D(x): 0.8804	D(G(z)): 0.3723 / 0.0757
[72/200][0/29]	Loss_D: 0.4174	Loss_G: 2.7219	D(x): 0.9418	D(G(z)): 0.2558 / 0.0839
[73/200][0/29]	Loss_D: 0.7115	Loss_G: 4.4713	D(x): 0.8897	D(G(z)): 0.4102 / 0.0184
[74/200][0/29]	Loss_D: 0.6028	Loss_G: 2.2723	D(x): 0.6216	D(G(z)): 0.0629 / 0.1468
[75/200][0/29]	Loss_D: 0.5563	Loss_G: 5.1553	D(x): 0.9362	D(G(z)): 0.3594 / 0.0086
[76/200][0/29]	Loss_D: 1.0016	Loss_G: 1.6670	D(x): 0.4265	D(G(z)): 0.0122 / 0.2555
[77/200][0/29]	Loss_D: 0.2427	Loss_G: 3.6158	D(x): 0.8555	D(G(z)): 0.0720 / 0.0366
[78/200][0/29]	Loss_D: 0.2694	Loss_G: 3.3786	D(x): 0.8747	D(G(z)): 0.1165 / 0.0483
[79/200][0/29]	Loss_D: 0.2493	Loss_G: 3.7566	D(x): 0.8945	D(G(z)): 0.1117 / 0.0345
[80/200][0/29]	Loss_D: 0.3370	Loss_G: 4.1660	D(x): 0.7929	D(G(z)): 0.0726 / 0.0267
[81/200][0/29]	Loss_D: 0.3224	Loss_G: 3.5349	D(x): 0.8616	D(G(z)): 0.1381 / 0.0433
[82/200][0/29]	Loss_D: 0.2955	Loss_G: 3.3791	D(x): 0.8748	D(G(z)): 0.1347 / 0.0451
[83/200][0/29]	Loss_D: 0.5670	Loss_G: 3.7841	D(x): 0.6302	D(G(z)): 0.0177 / 0.0449
[84/200][0/29]	Loss_D: 0.2506	Loss_G: 3.9737	D(x): 0.8464	D(G(z)): 0.0633 / 0.0317
[85/200][0/29]	Loss_D: 0.2872	Loss_G: 2.9568	D(x): 0.8615	D(G(z)): 0.1142 / 0.0698
[86/200][0/29]	Loss_D: 0.3689	Loss_G: 2.7059	D(x): 0.7692	D(G(z)): 0.0745 / 0.0957
[87/200][0/29]	Loss_D: 0.5402	Loss_G: 1.9860	D(x): 0.6419	D(G(z)): 0.0356 / 0.2136
[88/200][0/29]	Loss_D: 0.3196	Loss_G: 3.7599	D(x): 0.8548	D(G(z)): 0.1203 / 0.0376
[89/200][0/29]	Loss_D: 0.4026	Loss_G: 2.7127	D(x): 0.7800	D(G(z)): 0.1117 / 0.0927
[90/200][0/29]	Loss_D: 0.3601	Loss_G: 4.0378	D(x): 0.9564	D(G(z)): 0.2485 / 0.0266
[91/200][0/29]	Loss_D: 0.2515	Loss_G: 2.9888	D(x): 0.8320	D(G(z)): 0.0515 / 0.0757
[92/200][0/29]	Loss_D: 0.3404	Loss_G: 3.8484	D(x): 0.9341	D(G(z)): 0.2134 / 0.0376
[93/200][0/29]	Loss_D: 0.6964	Loss_G: 1.0576	D(x): 0.6028	D(G(z)): 0.0988 / 0.4150
[94/200][0/29]	Loss_D: 0.6395	Loss_G: 5.8440	D(x): 0.9440	D(G(z)): 0.3943 / 0.0042
[95/200][0/29]	Loss_D: 0.3331	Loss_G: 3.7265	D(x): 0.7653	D(G(z)): 0.0351 / 0.0414
[96/200][0/29]	Loss_D: 0.5095	Loss_G: 2.3112	D(x): 0.7385	D(G(z)): 0.1503 / 0.1438
[97/200][0/29]	Loss_D: 0.3365	Loss_G: 4.7906	D(x): 0.9267	D(G(z)): 0.2091 / 0.0119
[98/200][0/29]	Loss_D: 0.3815	Loss_G: 3.7188	D(x): 0.7703	D(G(z)): 0.0735 / 0.0396
[99/200][0/29]	Loss_D: 0.5143	Loss_G: 5.0483	D(x): 0.9308	D(G(z)): 0.3261 / 0.0098
[100/200][0/29]	Loss_D: 0.3817	Loss_G: 3.2539	D(x): 0.9360	D(G(z)): 0.2433 / 0.0597
[101/200][0/29]	Loss_D: 1.1037	Loss_G: 6.3981	D(x): 0.9843	D(G(z)): 0.5828 / 0.0049
[102/200][0/29]	Loss_D: 0.4680	Loss_G: 4.6143	D(x): 0.9520	D(G(z)): 0.3228 / 0.0141
[103/200][0/29]	Loss_D: 0.3327	Loss_G: 3.1352	D(x): 0.7873	D(G(z)): 0.0647 / 0.0638
[104/200][0/29]	Loss_D: 0.2936	Loss_G: 3.0109	D(x): 0.8509	D(G(z)): 0.0994 / 0.0692
[105/200][0/29]	Loss_D: 0.4057	Loss_G: 4.1849	D(x): 0.8990	D(G(z)): 0.2380 / 0.0219
[106/200][0/29]	Loss_D: 0.2802	Loss_G: 4.1243	D(x): 0.9699	D(G(z)): 0.2077 / 0.0258
[107/200][0/29]	Loss_D: 0.2500	Loss_G: 4.0079	D(x): 0.8557	D(G(z)): 0.0725 / 0.0362
[108/200][0/29]	Loss_D: 1.0960	Loss_G: 1.3010	D(x): 0.4173	D(G(z)): 0.0272 / 0.3611
[109/200][0/29]	Loss_D: 0.3803	Loss_G: 4.5367	D(x): 0.8962	D(G(z)): 0.2097 / 0.0182
[110/200][0/29]	Loss_D: 0.3492	Loss_G: 3.9755	D(x): 0.8844	D(G(z)): 0.1839 / 0.0295
[111/200][0/29]	Loss_D: 0.3885	Loss_G: 4.2899	D(x): 0.9030	D(G(z)): 0.2307 / 0.0207
[112/200][0/29]	Loss_D: 0.3561	Loss_G: 3.7313	D(x): 0.7728	D(G(z)): 0.0572 / 0.0474
[113/200][0/29]	Loss_D: 0.7096	Loss_G: 1.0246	D(x): 0.5761	D(G(z)): 0.0689 / 0.4702
[114/200][0/29]	Loss_D: 0.6662	Loss_G: 5.0370	D(x): 0.9764	D(G(z)): 0.4075 / 0.0119
[115/200][0/29]	Loss_D: 0.7960	Loss_G: 1.9850	D(x): 0.5622	D(G(z)): 0.0639 / 0.2129
[116/200][0/29]	Loss_D: 0.5462	Loss_G: 3.3335	D(x): 0.8498	D(G(z)): 0.2824 / 0.0507
[117/200][0/29]	Loss_D: 0.2578	Loss_G: 3.9922	D(x): 0.8994	D(G(z)): 0.1306 / 0.0267
[118/200][0/29]	Loss_D: 0.3090	Loss_G: 3.6995	D(x): 0.9747	D(G(z)): 0.2252 / 0.0352
[119/200][0/29]	Loss_D: 1.1299	Loss_G: 5.3906	D(x): 0.9829	D(G(z)): 0.5903 / 0.0114
[120/200][0/29]	Loss_D: 0.3311	Loss_G: 4.2024	D(x): 0.9587	D(G(z)): 0.2303 / 0.0219
[121/200][0/29]	Loss_D: 0.2622	Loss_G: 3.1597	D(x): 0.8327	D(G(z)): 0.0615 / 0.0665
[122/200][0/29]	Loss_D: 0.2289	Loss_G: 2.4934	D(x): 0.8626	D(G(z)): 0.0710 / 0.1091
[123/200][0/29]	Loss_D: 0.3967	Loss_G: 4.4024	D(x): 0.7482	D(G(z)): 0.0319 / 0.0355
[124/200][0/29]	Loss_D: 1.6501	Loss_G: 7.1195	D(x): 0.9983	D(G(z)): 0.7347 / 0.0023
[125/200][0/29]	Loss_D: 0.5998	Loss_G: 1.7626	D(x): 0.6508	D(G(z)): 0.0827 / 0.2181
[126/200][0/29]	Loss_D: 0.3701	Loss_G: 4.1205	D(x): 0.9295	D(G(z)): 0.2302 / 0.0226
[127/200][0/29]	Loss_D: 0.2658	Loss_G: 3.5970	D(x): 0.9246	D(G(z)): 0.1569 / 0.0426
[128/200][0/29]	Loss_D: 0.3830	Loss_G: 3.2428	D(x): 0.7625	D(G(z)): 0.0758 / 0.0687
[129/200][0/29]	Loss_D: 0.2601	Loss_G: 3.2132	D(x): 0.9072	D(G(z)): 0.1399 / 0.0564
[130/200][0/29]	Loss_D: 0.4376	Loss_G: 2.5952	D(x): 0.7958	D(G(z)): 0.1609 / 0.1282
[131/200][0/29]	Loss_D: 0.9269	Loss_G: 2.4691	D(x): 0.4850	D(G(z)): 0.0073 / 0.1778
[132/200][0/29]	Loss_D: 0.3672	Loss_G: 2.6825	D(x): 0.7670	D(G(z)): 0.0702 / 0.1006
[133/200][0/29]	Loss_D: 0.5517	Loss_G: 4.5020	D(x): 0.9705	D(G(z)): 0.3446 / 0.0193
[134/200][0/29]	Loss_D: 0.2822	Loss_G: 4.2797	D(x): 0.9425	D(G(z)): 0.1850 / 0.0232
[135/200][0/29]	Loss_D: 0.3999	Loss_G: 4.9715	D(x): 0.9167	D(G(z)): 0.2309 / 0.0122
[136/200][0/29]	Loss_D: 0.2145	Loss_G: 3.7001	D(x): 0.9641	D(G(z)): 0.1429 / 0.0413
[137/200][0/29]	Loss_D: 0.6803	Loss_G: 1.3819	D(x): 0.5780	D(G(z)): 0.0412 / 0.3743
[138/200][0/29]	Loss_D: 0.7697	Loss_G: 5.5670	D(x): 0.9379	D(G(z)): 0.4313 / 0.0104
[139/200][0/29]	Loss_D: 0.3736	Loss_G: 4.0882	D(x): 0.9144	D(G(z)): 0.2248 / 0.0285
[140/200][0/29]	Loss_D: 0.2976	Loss_G: 3.7616	D(x): 0.9899	D(G(z)): 0.2306 / 0.0330
[141/200][0/29]	Loss_D: 0.4253	Loss_G: 1.6148	D(x): 0.7243	D(G(z)): 0.0648 / 0.2629
[142/200][0/29]	Loss_D: 0.2587	Loss_G: 4.1045	D(x): 0.8236	D(G(z)): 0.0415 / 0.0253
[143/200][0/29]	Loss_D: 0.2976	Loss_G: 3.9097	D(x): 0.7965	D(G(z)): 0.0471 / 0.0389
[144/200][0/29]	Loss_D: 0.2434	Loss_G: 3.6706	D(x): 0.9330	D(G(z)): 0.1462 / 0.0380
[145/200][0/29]	Loss_D: 0.4002	Loss_G: 4.5206	D(x): 0.9398	D(G(z)): 0.2421 / 0.0192
[146/200][0/29]	Loss_D: 0.2757	Loss_G: 4.2653	D(x): 0.8882	D(G(z)): 0.1311 / 0.0211
[147/200][0/29]	Loss_D: 0.3333	Loss_G: 2.7953	D(x): 0.7940	D(G(z)): 0.0721 / 0.0921
[148/200][0/29]	Loss_D: 0.1250	Loss_G: 4.7287	D(x): 0.9469	D(G(z)): 0.0618 / 0.0171
[149/200][0/29]	Loss_D: 0.9288	Loss_G: 0.9298	D(x): 0.5105	D(G(z)): 0.0313 / 0.5688
[150/200][0/29]	Loss_D: 0.1876	Loss_G: 5.1894	D(x): 0.9509	D(G(z)): 0.1150 / 0.0109
[151/200][0/29]	Loss_D: 0.2148	Loss_G: 3.2431	D(x): 0.8598	D(G(z)): 0.0510 / 0.0691
[152/200][0/29]	Loss_D: 0.1630	Loss_G: 3.8535	D(x): 0.9149	D(G(z)): 0.0662 / 0.0350
[153/200][0/29]	Loss_D: 0.2098	Loss_G: 4.6260	D(x): 0.9393	D(G(z)): 0.1280 / 0.0154
[154/200][0/29]	Loss_D: 0.5804	Loss_G: 2.6807	D(x): 0.6875	D(G(z)): 0.0369 / 0.1517
[155/200][0/29]	Loss_D: 0.2371	Loss_G: 2.8880	D(x): 0.8814	D(G(z)): 0.0907 / 0.0953
[156/200][0/29]	Loss_D: 0.2377	Loss_G: 4.2129	D(x): 0.9806	D(G(z)): 0.1749 / 0.0220
[157/200][0/29]	Loss_D: 0.1484	Loss_G: 4.1530	D(x): 0.9392	D(G(z)): 0.0781 / 0.0254
[158/200][0/29]	Loss_D: 0.1881	Loss_G: 4.0782	D(x): 0.9557	D(G(z)): 0.1228 / 0.0251
[159/200][0/29]	Loss_D: 0.1797	Loss_G: 3.4668	D(x): 0.9657	D(G(z)): 0.1279 / 0.0440
[160/200][0/29]	Loss_D: 2.2162	Loss_G: 6.8132	D(x): 0.7240	D(G(z)): 0.7128 / 0.0052
[161/200][0/29]	Loss_D: 1.3703	Loss_G: 1.8879	D(x): 0.3353	D(G(z)): 0.0240 / 0.2558
[162/200][0/29]	Loss_D: 0.3567	Loss_G: 3.3365	D(x): 0.8590	D(G(z)): 0.1585 / 0.0602
[163/200][0/29]	Loss_D: 0.3430	Loss_G: 3.2457	D(x): 0.7923	D(G(z)): 0.0803 / 0.0680
[164/200][0/29]	Loss_D: 0.2954	Loss_G: 2.6320	D(x): 0.8136	D(G(z)): 0.0702 / 0.1032
[165/200][0/29]	Loss_D: 0.2102	Loss_G: 3.9000	D(x): 0.9612	D(G(z)): 0.1435 / 0.0335
[166/200][0/29]	Loss_D: 0.2672	Loss_G: 4.0660	D(x): 0.9835	D(G(z)): 0.2034 / 0.0251
[167/200][0/29]	Loss_D: 0.2352	Loss_G: 3.8888	D(x): 0.9299	D(G(z)): 0.1354 / 0.0322
[168/200][0/29]	Loss_D: 0.1258	Loss_G: 4.1186	D(x): 0.9461	D(G(z)): 0.0625 / 0.0245
[169/200][0/29]	Loss_D: 0.8807	Loss_G: 5.6020	D(x): 0.9697	D(G(z)): 0.4804 / 0.0106
[170/200][0/29]	Loss_D: 0.2388	Loss_G: 3.9026	D(x): 0.9785	D(G(z)): 0.1809 / 0.0297
[171/200][0/29]	Loss_D: 0.6654	Loss_G: 1.2778	D(x): 0.5977	D(G(z)): 0.0779 / 0.3465
[172/200][0/29]	Loss_D: 0.2370	Loss_G: 3.7885	D(x): 0.9799	D(G(z)): 0.1799 / 0.0341
[173/200][0/29]	Loss_D: 0.1465	Loss_G: 3.5408	D(x): 0.9566	D(G(z)): 0.0924 / 0.0445
[174/200][0/29]	Loss_D: 0.1991	Loss_G: 4.0896	D(x): 0.9473	D(G(z)): 0.1264 / 0.0247
[175/200][0/29]	Loss_D: 0.1970	Loss_G: 3.6667	D(x): 0.9599	D(G(z)): 0.1346 / 0.0353
[176/200][0/29]	Loss_D: 0.1672	Loss_G: 3.7477	D(x): 0.9825	D(G(z)): 0.1290 / 0.0384
[177/200][0/29]	Loss_D: 0.1903	Loss_G: 3.2850	D(x): 0.9134	D(G(z)): 0.0887 / 0.0539
[178/200][0/29]	Loss_D: 0.0840	Loss_G: 4.7847	D(x): 0.9628	D(G(z)): 0.0410 / 0.0131
[179/200][0/29]	Loss_D: 0.1894	Loss_G: 3.3210	D(x): 0.8698	D(G(z)): 0.0397 / 0.0561
[180/200][0/29]	Loss_D: 1.6561	Loss_G: 1.6580	D(x): 0.4067	D(G(z)): 0.0230 / 0.4182
[181/200][0/29]	Loss_D: 0.3051	Loss_G: 2.9891	D(x): 0.8337	D(G(z)): 0.0913 / 0.0962
[182/200][0/29]	Loss_D: 0.2621	Loss_G: 2.7532	D(x): 0.8457	D(G(z)): 0.0768 / 0.1001
[183/200][0/29]	Loss_D: 0.2239	Loss_G: 3.1668	D(x): 0.8734	D(G(z)): 0.0717 / 0.0647
[184/200][0/29]	Loss_D: 0.2601	Loss_G: 2.3965	D(x): 0.8318	D(G(z)): 0.0618 / 0.1315
[185/200][0/29]	Loss_D: 0.3017	Loss_G: 3.2990	D(x): 0.9551	D(G(z)): 0.1996 / 0.0499
[186/200][0/29]	Loss_D: 0.0779	Loss_G: 4.4214	D(x): 0.9772	D(G(z)): 0.0512 / 0.0199
[187/200][0/29]	Loss_D: 0.1287	Loss_G: 3.6913	D(x): 0.9130	D(G(z)): 0.0333 / 0.0364
[188/200][0/29]	Loss_D: 0.3643	Loss_G: 4.8397	D(x): 0.9826	D(G(z)): 0.2531 / 0.0136
[189/200][0/29]	Loss_D: 1.7485	Loss_G: 1.5066	D(x): 0.4216	D(G(z)): 0.1263 / 0.3564
[190/200][0/29]	Loss_D: 0.5253	Loss_G: 4.1727	D(x): 0.9366	D(G(z)): 0.3164 / 0.0356
[191/200][0/29]	Loss_D: 0.1440	Loss_G: 3.8928	D(x): 0.9371	D(G(z)): 0.0708 / 0.0320
[192/200][0/29]	Loss_D: 0.2365	Loss_G: 3.8026	D(x): 0.9772	D(G(z)): 0.1764 / 0.0339
[193/200][0/29]	Loss_D: 0.5132	Loss_G: 5.4322	D(x): 0.9909	D(G(z)): 0.3396 / 0.0078
[194/200][0/29]	Loss_D: 0.1455	Loss_G: 3.5173	D(x): 0.9033	D(G(z)): 0.0365 / 0.0536
[195/200][0/29]	Loss_D: 0.2652	Loss_G: 2.9854	D(x): 0.8815	D(G(z)): 0.1184 / 0.0700
[196/200][0/29]	Loss_D: 0.2310	Loss_G: 3.1726	D(x): 0.8942	D(G(z)): 0.1041 / 0.0614
[197/200][0/29]	Loss_D: 0.2111	Loss_G: 3.2471	D(x): 0.8369	D(G(z)): 0.0248 / 0.0649
[198/200][0/29]	Loss_D: 0.1308	Loss_G: 3.0206	D(x): 0.9075	D(G(z)): 0.0291 / 0.0863
[199/200][0/29]	Loss_D: 0.1629	Loss_G: 4.1042	D(x): 0.9336	D(G(z)): 0.0794 / 0.0316
CPU times: user 28min 42s, sys: 33.5 s, total: 29min 15s
Wall time: 58min 39s

Results

Finally, let's check out how we did. Here, we will look at three different results. First, we will see how D and G's losses changed during training. Second, we will visualize G's output on the fixed_noise batch saved periodically during training. And third, we will look at a batch of real data next to a batch of fake data from G.

Loss versus training iteration

Below is a plot of D & G’s losses versus training iterations.

In [18]:
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

Visualization of G’s progression

Remember how we saved the generator's output on the fixed_noise batch periodically during training. Now, we can visualize the training progression of G with an animation. Press the play button to start the animation.

In [19]:
#%%capture
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())
Out[19]:

Real Images vs. Fake Images

Finally, let's take a look at some real images and fake images side by side.

In [20]:
# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()

Where to Go Next

We have reached the end of our journey, but there are several places you could go from here. You could:

  • Train for longer to see how good the results get
  • Modify this model to take a different dataset and possibly change the size of the images and the model architecture
  • Check out some other cool GAN projects at https://github.com/nashory/gans-awesome-applications
  • Create GANs that generate music: https://deepmind.com/blog/wavenet-generative-model-raw-audio/